/*
 * Parallelising JVM Compiler
 *
 * Copyright 2010 Peter Calvert, University of Cambridge
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License
 */

package cuda;

import exceptions.UnsupportedInstruction;

import analysis.dataflow.SimpleUsed;

import graph.BasicBlock;
import graph.ClassNode;
import graph.Kernel;
import graph.Method;
import graph.Modifier;
import graph.Type;

import graph.instructions.*;

import graph.state.Field;
import graph.state.State;
import graph.state.Variable;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;

import java.util.Collection;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

import org.apache.log4j.Logger;

/**
 * Top level class for export of code to CUDA. Used both to invoke other
 * exporter classes (<code>cuda.BlockExporter</code>) and the NVCC compiler
 * itself.
 */
public class CUDAExporter {
  /**
   * Set of methods (including kernels) that have been exported to the current
   * CUDA file.
   */
  static private final Set<Method> exported = new HashSet<Method>();

  /**
   * Set of statics which have been defined for CUDA.
   */
  static private final Set<Field> statics = new HashSet<Field>();

  /**
   * Set of classes which have been defined for CUDA.
   */
  static private final Set<ClassNode> classes = new HashSet<ClassNode>();

  /**
   * Print stream for source exports (<code>null</code> until
   * {@link #setDestination} has been called).
   */
  static private PrintStream out;

  /**
   * File for source output.
   */
  static private File source;

  /**
   * Library name.
   */
  static private String libraryName;

  /**
   * File for library output (<code>null</code> when the source is not to be
   * compiled).
   */
  static private File sharedLibrary;

  /**
   * Location of NVCC compiler.
   */
  static private File nvcc;

  /**
   * JDK home directory.
   */
  static private File jdk;

  /**
   * Directory holding the parallel headers such as <code>parallel.h</code>
   * (defaults to same as JAR — chosen by the caller of {@link #setSystem}).
   */
  static private File includes;

  /**
   * GPU architecture passed to NVCC via <code>-arch</code>.
   */
  static private String architecture = "sm_13";

  /**
   * Sets the required paths for the compiler step. NVCC is expected at
   * <code>cuda/bin/nvcc</code>.
   *
   * @param cuda     CUDA home directory.
   * @param jdk      JDK home directory.
   * @param includes Includes directory.
   */
  public static void setSystem(File cuda, File jdk, File includes) {
    CUDAExporter.nvcc     = new File(cuda, "bin" + File.separator + "nvcc");
    CUDAExporter.jdk      = jdk;
    CUDAExporter.includes = includes;
  }

  /**
   * Sets the GPU architecture that NVCC should target (defaults to
   * <code>sm_13</code>).
   *
   * @param arch Architecture name as accepted by <code>nvcc -arch</code>.
   * @throws IllegalArgumentException if <code>arch</code> is null or blank.
   */
  public static void setArchitecture(String arch) {
    if(arch == null || arch.trim().length() == 0) {
      throw new IllegalArgumentException("Architecture must be non-empty.");
    }

    architecture = arch.trim();
  }

  /**
   * Sets the destinations for CUDA exports. These can be changed, but some
   * methods may then be reexported, as exports are independent. Any previous
   * destination stream is closed once the new one has been opened.
   *
   * @param dir     Directory in which outputs should occur.
   * @param library Name of the destination library.
   * @param compile Whether the source should be compiled or not.
   * @throws IOException if the source file cannot be created or opened.
   */
  public static void setDestination(File dir, String library, boolean compile) throws IOException {
    // Create directory.
    dir.mkdirs();

    // Work out names first, so a failure leaves the previous state intact.
    File newSource;
    File newLibrary;

    if(compile) {
      // The source is only an intermediate, so use a temporary file.
      newSource  = File.createTempFile("src", ".cu", dir);
      newLibrary = new File(dir, System.mapLibraryName(library));

      newSource.deleteOnExit();
    } else {
      newSource  = new File(dir, System.mapLibraryName(library) + ".cu");
      newLibrary = null;
    }

    // Create Beautified Printstream.
    PrintStream newOut = new Beautifier(newSource);

    // Release the previous destination (if any) now the new one is open.
    if(out != null) {
      out.close();
    }

    libraryName   = library;
    source        = newSource;
    sharedLibrary = newLibrary;
    out           = newOut;

    // File Header
    out.println("// Autogenerated CUDA code.");
    out.println("#include <parallel.h>");
    out.println();

    // Clear export sets.
    exported.clear();
    statics.clear();
    classes.clear();
  }

  /**
   * Invokes compiler on current source file (if library output is specified).
   * NVCC's combined stdout/stderr is logged at debug level to the
   * <code>cuda.nvcc</code> logger.
   *
   * @throws IOException           if NVCC cannot be started, is interrupted,
   *                               or exits with a non-zero status.
   * @throws IllegalStateException if {@link #setSystem} has not been called.
   */
  public static void compile() throws IOException {
    if(sharedLibrary == null) {
      return;
    }

    if(nvcc == null) {
      throw new IllegalStateException("NVCC location not set; call setSystem first.");
    }

    // Ensure everything exported so far has reached the source file before
    // NVCC reads it.
    out.flush();

    File jdkIncludes = new File(jdk, "include");

    Process build = new ProcessBuilder(
      nvcc.getAbsolutePath(),          // Run NVCC
      "--shared",                      // Produce shared code.
      "-o",                            // Output to ...
      sharedLibrary.getAbsolutePath(), // ... required place.
      "-Xcompiler",                    // We want ...
      "-fPIC",                         // ... position independent code.
      "-arch", architecture,           // Target GPU architecture.
      "-I",                            // Need to include JDK headers ...
      jdkIncludes.getAbsolutePath(),
      "-I",                            // ... and the platform specific ones ...
      new File(jdkIncludes, platformIncludeDir()).getAbsolutePath(),
      "-I",                            // ... and the parallel headers.
      includes.getAbsolutePath(),
      source.getAbsolutePath()
    ).redirectErrorStream(true).start();

    // Drain the output fully before waiting, so NVCC cannot block on a full
    // pipe.
    BufferedReader output = new BufferedReader(
      new InputStreamReader(build.getInputStream())
    );

    try {
      String s;
      while((s = output.readLine()) != null) {
        if(s.trim().length() > 0) {
          Logger.getLogger("cuda.nvcc").debug(s);
        }
      }
    } finally {
      output.close();
    }

    int status;

    try {
      status = build.waitFor();
    } catch(InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IOException("Interrupted while waiting for NVCC.", e);
    }

    if(status != 0) {
      throw new IOException(
        "NVCC exited with status " + status + " compiling "
          + source.getAbsolutePath() + " (see the cuda.nvcc log for details)."
      );
    }
  }

  /**
   * Name of the platform specific JNI header directory beneath
   * <code>$JDK/include</code>. The JDK uses <code>win32</code> on Windows and
   * <code>darwin</code> on macOS; elsewhere it is the lower-cased OS name
   * (e.g. <code>linux</code>).
   */
  private static String platformIncludeDir() {
    String os = System.getProperty("os.name", "other").toLowerCase(Locale.ROOT);

    if(os.startsWith("windows")) {
      return "win32";
    } else if(os.startsWith("mac") || os.startsWith("darwin")) {
      return "darwin";
    } else {
      return os;
    }
  }

  /**
   * Exports a kernel to the current CUDA file. Statics used are emitted as
   * <code>__constant__</code> definitions and classes used are defined (each
   * only once per destination). The kernel body and its launcher are then
   * written, and the kernel is marked as native — presumably so the JVM binds
   * it to the generated launcher.
   *
   * @param kernel Kernel to export.
   * @throws UnsupportedInstruction propagated from code generation (presumably
   *                                for instructions with no CUDA equivalent).
   */
  public static void export(Kernel kernel) throws UnsupportedInstruction {
    // Check not already exported.
    if(exported.contains(kernel)) {
      return;
    }

    // Calculate state used by kernel.
    SimpleUsed used = new SimpleUsed(kernel.getImplementation());

    // Define statics used in kernel (and functions called by it).
    for(Field f : used.getStatics()) {
      if(!statics.contains(f)) {
        out.println("__constant__ " + Helper.getType(f.getType()) + " " +
                    Helper.getName(f) + ";");
        statics.add(f);
      }
    }

    // Define classes used in kernel (and functions called by it).
    for(ClassNode c : used.getClasses()) {
      if(!classes.contains(c)) {
        Helper.defineClass(c, out);
        classes.add(c);
      }
    }

    // Build the kernel in memory, so a failure part way through leaves no
    // partial kernel in the output file.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    PrintStream temp = new Beautifier(bytes);

    // Head
    Helper.kernelStart(kernel, temp);

    // Locals and stack restoration variables.
    declareVariables(used, kernel.getParameterVariables(), temp);

    // Export Body
    kernel.getImplementation().accept(new BlockExporter(temp));

    // Tail
    Helper.kernelEnd(kernel, temp);

    // Launcher
    Helper.launcher(kernel, temp);

    // Commit export to file.
    commitExport(temp, bytes, "kernel");

    // Mark kernel as being native.
    kernel.getModifiers().add(Modifier.NATIVE);

    // Add to exported set.
    exported.add(kernel);
  }

  /**
   * Exports a standard method to the current CUDA file as a device function.
   *
   * @param method Method to export.
   * @throws UnsupportedInstruction propagated from code generation (presumably
   *                                for instructions with no CUDA equivalent).
   */
  public static void export(Method method) throws UnsupportedInstruction {
    // Check not already exported.
    if(exported.contains(method)) {
      return;
    }

    // Build the function in memory first (see export(Kernel)).
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    PrintStream temp = new Beautifier(bytes);

    // Head
    Helper.methodStart(method, temp);

    // Locals and stack restoration variables.
    SimpleUsed used = new SimpleUsed(method.getImplementation());
    declareVariables(used, method.getParameterVariables(), temp);

    // Export Body
    method.getImplementation().accept(new BlockExporter(temp));

    // Tail
    Helper.methodEnd(method, temp);

    // Commit export to file.
    commitExport(temp, bytes, "method");

    // Add to exported set.
    exported.add(method);
  }

  /**
   * Writes declarations for the local variables used by an implementation
   * (excluding its parameters) and one stack restoration variable per
   * (stack slot, type) pair reported by <code>used</code>.
   *
   * @param used       State usage of the implementation.
   * @param parameters Parameter variables to exclude (presumably already
   *                   declared by the function header).
   * @param temp       Stream to write declarations to.
   */
  private static void declareVariables(SimpleUsed used, Collection<?> parameters,
                                       PrintStream temp) throws UnsupportedInstruction {
    final Set<Variable> varDeclare = new HashSet<Variable>(used.getVariables());

    varDeclare.removeAll(parameters);

    for(State s : varDeclare) {
      temp.println(Helper.getType(s.getType()) + " " + Helper.getName(s) + ";");
    }

    for(int i = 0; i < used.getStackCount(); i++) {
      for(Type t : used.getStackTypes(i)) {
        temp.println(
          Helper.getType(t) + " s" + Helper.getName(new Variable(i, t)) + ";"
        );
      }
    }
  }

  /**
   * Appends a fully generated export to the current CUDA file.
   *
   * @param temp  Stream the export was written to (flushed first).
   * @param bytes Buffer underlying <code>temp</code>.
   * @param what  Kind of export, used in the error message.
   */
  private static void commitExport(PrintStream temp, ByteArrayOutputStream bytes,
                                   String what) {
    temp.flush();

    try {
      out.write(bytes.toByteArray());
    } catch(IOException ex) {
      throw new RuntimeException("Unexpected file error in exporting " + what + ".", ex);
    }
  }

  /**
   * Prepends a <code>System.loadLibrary</code> call for the current library
   * to the static initialiser (<code>&lt;clinit&gt;</code>) of the given
   * class. If the initialiser has no implementation, the new block simply
   * returns after loading.
   *
   * @param clazz Class whose static initialiser should load the library.
   * @throws IllegalStateException if {@link #setDestination} has not been
   *                               called.
   */
  public static void addLoad(ClassNode clazz) {
    if(libraryName == null) {
      throw new IllegalStateException("No CUDA destination set; call setDestination first.");
    }

    // NOTE(review): the result is dereferenced unconditionally, so this
    // assumes getMethod yields a <clinit> even for classes without one —
    // confirm.
    Method clinit = clazz.getMethod("<clinit>", "()V");

    BasicBlock bb = new BasicBlock();
    ClassNode system = ClassNode.getClass("java/lang/System");

    bb.getStateful().add(
      new Call(
        new Producer[] {new Constant(libraryName)},
        system.getMethod("loadLibrary", "(Ljava/lang/String;)V"),
        Call.Sort.STATIC
      )
    );

    // Either terminate the new block, or chain it in front of the existing
    // initialiser code.
    if(clinit.getImplementation() == null) {
      bb.setBranch(new Return());
    } else {
      bb.setNext(clinit.getImplementation());
    }

    clinit.setImplementation(bb);
  }
}
